# Copyright 2018 The Kubeflow Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

from kfp.deprecated.dsl import Condition
from kfp.deprecated.dsl import ContainerOp
from kfp.deprecated.dsl import ExitHandler
from kfp.deprecated.dsl import OpsGroup
from kfp.deprecated.dsl import Pipeline
import kfp.deprecated.dsl as dsl


class TestOpsGroup(unittest.TestCase):

    def test_basic(self):
        """Test basic usage."""
        with Pipeline('somename') as p:
            self.assertEqual(1, len(p.groups))
            with OpsGroup(group_type='exit_handler'):
                op1 = ContainerOp(name='op1', image='image')
                with OpsGroup(group_type='branch'):
                    op2 = ContainerOp(name='op2', image='image')
                    op3 = ContainerOp(name='op3', image='image')
                with OpsGroup(group_type='loop'):
                    op4 = ContainerOp(name='op4', image='image')

        self.assertEqual(1, len(p.groups))
        self.assertEqual(1, len(p.groups[0].groups))
        exit_handler_group = p.groups[0].groups[0]
        self.assertEqual('exit_handler', exit_handler_group.type)
        self.assertEqual(2, len(exit_handler_group.groups))
        self.assertEqual(1, len(exit_handler_group.ops))
        self.assertEqual('op1', exit_handler_group.ops[0].name)

        branch_group = exit_handler_group.groups[0]
        self.assertFalse(branch_group.groups)
        self.assertCountEqual([x.name for x in branch_group.ops],
                              ['op2', 'op3'])

        loop_group = exit_handler_group.groups[1]
        self.assertFalse(loop_group.groups)
        self.assertCountEqual([x.name for x in loop_group.ops], ['op4'])

    def test_basic_recursive_opsgroups(self):
        """Test recursive opsgroups."""
        with Pipeline('somename') as p:
            self.assertEqual(1, len(p.groups))

            # When a graph opsgraph is called.
            graph_ops_group_one = dsl._ops_group.Graph('hello')
            graph_ops_group_one.__enter__()
            self.assertFalse(graph_ops_group_one.recursive_ref)
            self.assertEqual('graph-hello-1', graph_ops_group_one.name)

            # Another graph opsgraph is called with the same name
            # when the previous graph opsgraphs is not finished.
            graph_ops_group_two = dsl._ops_group.Graph('hello')
            graph_ops_group_two.__enter__()
            self.assertTrue(graph_ops_group_two.recursive_ref)
            self.assertEqual(graph_ops_group_one,
                             graph_ops_group_two.recursive_ref)

    def test_recursive_opsgroups_with_prefix_names(self):
        """Test recursive opsgroups."""
        with Pipeline('somename') as p:
            self.assertEqual(1, len(p.groups))

            # When a graph opsgraph is called.
            graph_ops_group_one = dsl._ops_group.Graph('foo_bar')
            graph_ops_group_one.__enter__()
            self.assertFalse(graph_ops_group_one.recursive_ref)
            self.assertEqual('graph-foo-bar-1', graph_ops_group_one.name)

            # Another graph opsgraph is called with the name as the prefix of the ops_group_one
            # when the previous graph opsgraphs is not finished.
            graph_ops_group_two = dsl._ops_group.Graph('foo')
            graph_ops_group_two.__enter__()
            self.assertFalse(graph_ops_group_two.recursive_ref)


class TestExitHandler(unittest.TestCase):

    def test_basic(self):
        """Test basic usage."""
        with Pipeline('somename') as p:
            exit_op = ContainerOp(name='exit', image='image')
            with ExitHandler(exit_op=exit_op):
                op1 = ContainerOp(name='op1', image='image')

        exit_handler = p.groups[0].groups[0]
        self.assertEqual('exit_handler', exit_handler.type)
        self.assertEqual('exit', exit_handler.exit_op.name)
        self.assertEqual(1, len(exit_handler.ops))
        self.assertEqual('op1', exit_handler.ops[0].name)

    def test_invalid_exit_op(self):
        with self.assertRaises(ValueError):
            with Pipeline('somename') as p:
                op1 = ContainerOp(name='op1', image='image')
                exit_op = ContainerOp(name='exit', image='image')
                exit_op.after(op1)
                with ExitHandler(exit_op=exit_op):
                    pass


class TestConditionOp(unittest.TestCase):

    def test_basic(self):
        with Pipeline('somename') as p:
            param1 = 'pizza'
            condition1 = Condition(param1 == 'pizza')
            self.assertEqual(condition1.name, None)
            with condition1:
                pass
            self.assertEqual(condition1.name, 'condition-1')

            condition2 = Condition(param1 == 'pizza', '[param1 is pizza]')
            self.assertEqual(condition2.name, '[param1 is pizza]')
            with condition2:
                pass
            self.assertEqual(condition2.name, 'condition-[param1 is pizza]-2')
